{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "71a329f0",
   "metadata": {},
   "source": [
    "# Zero shot object detection instruction for using LMI container on SageMaker\n",
    "In this tutorial, you will use LMI container from DLC to SageMaker and run inference with it.\n",
    "\n",
    "Please make sure the following permission granted before running the notebook:\n",
    "\n",
    "- S3 bucket push access\n",
    "- SageMaker access\n",
    "\n",
    "## Step 1: Let's bump up SageMaker and import stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67fa3208",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker boto3 awscli --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9ac353",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.djl_inference.model import DJLModel\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81deac79",
   "metadata": {},
   "source": [
    "## Step 2: Start preparing model artifacts\n",
    "In LMI contianer, we expect some artifacts to help setting up the model\n",
    "- serving.properties (optional): Defines the model server settings\n",
    "- model.py (required): A python file to define the core inference logic\n",
    "- requirements.txt (optional): Any additional pip wheel need to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b011bf5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile serving.properties\n",
    "engine=Python\n",
    "# enable dynamic server side batch\n",
    "# batch_size=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d6798b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "\n",
    "\n",
    "import logging\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "from PIL import Image\n",
    "from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection\n",
    "\n",
    "from djl_python import Input\n",
    "from djl_python import Output\n",
    "\n",
    "\n",
    "class ZeroShotObjectDetection(object):\n",
    "\n",
    "    def __init__(self):\n",
    "        self.device = None\n",
    "        self.model = None\n",
    "        self.processor = None\n",
    "        self.initialized = False\n",
    "\n",
    "    def initialize(self, properties: dict):\n",
    "        \"\"\"\n",
    "        Initialize model.\n",
    "        \"\"\"\n",
    "        model_id = \"IDEA-Research/grounding-dino-base\"\n",
    "        device_id = properties.get(\"device_id\", \"-1\")\n",
    "        device_id = \"cpu\" if device_id == \"-1\" else \"cuda:\" + device_id\n",
    "        self.device = torch.device(device_id)\n",
    "        self.processor = AutoProcessor.from_pretrained(model_id)\n",
    "        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(\n",
    "            model_id).to(self.device)\n",
    "        self.initialized = True\n",
    "\n",
    "    def inference(self, inputs):\n",
    "        outputs = Output()\n",
    "        try:\n",
    "            batch = inputs.get_batches()\n",
    "            images = []\n",
    "            text = []\n",
    "            sizes = []\n",
    "            for i, item in enumerate(batch):\n",
    "                data = item.get_as_json()\n",
    "                data = data.pop(\"inputs\", data)\n",
    "                image = Image.open(\n",
    "                    requests.get(data[\"image_url\"][\"url\"], stream=True).raw)\n",
    "                images.append(image)\n",
    "                text.append(data[\"text\"])\n",
    "                sizes.append(image.size[::-1])\n",
    "\n",
    "            model_inputs = self.processor(images=images,\n",
    "                                          text=text,\n",
    "                                          return_tensors=\"pt\").to(self.device)\n",
    "            with torch.no_grad():\n",
    "                model_outputs = self.model(**model_inputs)\n",
    "\n",
    "            results = self.processor.post_process_grounded_object_detection(\n",
    "                model_outputs,\n",
    "                model_inputs.input_ids,\n",
    "                box_threshold=0.4,\n",
    "                text_threshold=0.3,\n",
    "                target_sizes=sizes)\n",
    "            for i, result in enumerate(results):\n",
    "                ret = {\n",
    "                    \"labels\": result[\"labels\"],\n",
    "                    \"scores\": result[\"scores\"].tolist(),\n",
    "                    \"boxes\": result[\"boxes\"].cpu().detach().numpy().tolist(),\n",
    "                }\n",
    "                if inputs.is_batch():\n",
    "                    outputs.add_as_json(ret, batch_index=i)\n",
    "                else:\n",
    "                    outputs.add_as_json(ret)\n",
    "        except Exception as e:\n",
    "            logging.exception(\"ZeroShotObjectDetection inference failed\")\n",
    "            # error handling\n",
    "            outputs = Output().error(str(e))\n",
    "\n",
    "        return outputs\n",
    "\n",
    "\n",
    "_service = ZeroShotObjectDetection()\n",
    "\n",
    "\n",
    "def handle(inputs: Input):\n",
    "    \"\"\"\n",
    "    Default handler function\n",
    "    \"\"\"\n",
    "    if not _service.initialized:\n",
    "        # stateful model\n",
    "        _service.initialize(inputs.get_properties())\n",
    "\n",
    "    if inputs.is_empty():\n",
    "        # initialization request\n",
    "        return None\n",
    "\n",
    "    return _service.inference(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b50a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%writefile requirements.txt\n",
    "# Start writing content here (remove this file if not neeed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0142973",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "mv model.py mymodel/\n",
    "# mv requirements.txt mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e58cf33",
   "metadata": {},
   "source": [
    "## Step 3: Upload model artifact to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b1e5ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "bucket = session.default_bucket()  # default bucket to host artifacts\n",
    "code_artifact = session.upload_data(\"mymodel.tar.gz\", bucket, \"lmi-model\")\n",
    "print(f\"S3 Code or Model tar ball uploaded to --- > {code_artifact}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "004f39f6",
   "metadata": {},
   "source": [
    "## Step 4: Start building SageMaker endpoint\n",
    "In this step, we will build SageMaker endpoint from scratch\n",
    "\n",
    "### Getting the container image URI (optional)\n",
    "\n",
    "Check out available images: [Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d38284-1069-4503-801b-1254626bb730",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose a specific version of LMI image directly:\n",
    "# image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79970e03-b49b-4363-99f0-8c90367a9c55",
   "metadata": {},
   "source": [
    "### Create SageMaker model\n",
    "\n",
    "Here we are using [LMI PySDK](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to create the model.\n",
    "\n",
    "Checkout more [configuration options](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/configurations.html#environment-variable-configurations)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d382700-f219-4d2a-b726-21f77b9c0958",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DJLModel(\n",
    "    model_data=code_artifact,\n",
    "    #image_uri=image_uri,          # choose a specific version of LMI DLC image\n",
    "    role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e247e78c-460f-42a9-bd21-606a9b21b012",
   "metadata": {},
   "source": [
    "### Create SageMaker endpoint\n",
    "\n",
    "You need to specify the instance to use and endpoint names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e0e61cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "instance_type = \"ml.g4dn.2xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "predictor = model.deploy(initial_instance_count=1,\n",
    "    instance_type=instance_type,\n",
    "    endpoint_name=endpoint_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb63ee65",
   "metadata": {},
   "source": [
    "## Step 5: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bcef095",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\n",
    "        \"text\": \"a cat. a remote control.\",\n",
    "        \"image_url\": {\n",
    "            \"url\": \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
    "        }\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1cd9042",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d674b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "session.delete_endpoint(endpoint_name)\n",
    "session.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ef658c-bcf6-40d3-aca5-da8d5e8f323a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
